open Usuba_AST
open Basic_utils
open Utils


(** Register width, in bits, to use as the default [BITS_PER_REG] of the
    generated code.

    - For the [Std] architecture, look at the type of the first input of
      the last node of [prog] (the main node). If that type is vertically
      sliced ([Vslice]), the width is the [x] of its [Mint x] machine type.
      For any other direction, [conf.bits_per_reg] is returned.
    - For every other architecture, [default_bits_per_reg conf.archi] is
      returned and [conf.bits_per_reg] is not consulted.

    Raises [Failure] (from [List.hd]) when [archi] is [Std] and the main
    node has no input. Fails with [assert false] when a vslice input does
    not have a [Mint] machine type. *)
let bits_per_reg (prog:prog) (conf:config) : int =
  match conf.archi with
  | Std ->
     let first_input_typ = (List.hd (last prog.nodes).p_in).vd_typ in
     begin match get_type_dir first_input_typ with
     | Vslice ->
        begin match get_type_m first_input_typ with
        | Mint width -> width
        | _ -> assert false
        end
     | _ -> conf.bits_per_reg
     end
  | _ -> default_bits_per_reg conf.archi

(** Generates the complete C source file for [prog], returned as a string.

    The output contains, in order:
    - a header comment naming [filename];
    - [#define MASKING_ORDER] set to [conf.shares] when [conf.shares <> 1];
    - a default [BITS_PER_REG] computed by [bits_per_reg prog conf];
    - an [#include] of the architecture header given by the selected
      backend for [conf.archi];
    - the C code of the auxiliary nodes, produced through [map_no_end]
      with [false] as the entry flag. This presumably covers every node
      except the last one, which is emitted separately.
    - the C code of the last node (the main function), produced with
      [conf.arr_entry] as the entry flag;
    - the benchmark code when [conf.gen_bench] is set, [""] otherwise;
    - the Usuba program [orig], pretty-printed inside a trailing C comment.

    [orig] is only printed; all C code is generated from [prog].
    NOTE(review): [prog] is presumably the normalized/optimized form of
    [orig]; confirm with callers. *)
let gen_runtime (orig:prog) (prog:prog) (conf:config) (filename:string) : string =

  (* Backend selection, shared by every fragment below so the node
     definitions and the header always come from the same backend.
     Priority: fdti, then masked, then ua_masked, then plain. *)
  let def_to_c node is_entry =
    if conf.fdti <> "" then Nodes_to_c_fdti.def_to_c node is_entry conf
    else if conf.masked then Nodes_to_c_masked.def_to_c node is_entry conf
    else if conf.ua_masked then Nodes_to_c_ua_masked.def_to_c node is_entry conf
    else Nodes_to_c.def_to_c node is_entry conf in
  let c_header archi =
    if conf.fdti <> "" then Nodes_to_c_fdti.c_header archi
    else if conf.masked then Nodes_to_c_masked.c_header archi
    else if conf.ua_masked then Nodes_to_c_ua_masked.c_header archi
    else Nodes_to_c.c_header archi in

  (* The last node is the main function, fetched with [last] as in
     [bits_per_reg]; it is the only one generated with [conf.arr_entry]. *)
  let main = last prog.nodes in
  let entry = def_to_c main conf.arr_entry in
  let prog_c = map_no_end (fun node -> def_to_c node false) prog.nodes in

  (* NOTE(review): there is no fdti branch here, so when [conf.fdti] is
     set the benchmark comes from the masked, ua_masked or plain generator.
     Kept as in the original; confirm this is intended. *)
  let bench_fun =
    if not conf.gen_bench then ""
    else if conf.masked then Nodes_to_c_masked.gen_bench main conf
    else if conf.ua_masked then Nodes_to_c_ua_masked.gen_bench main conf
    else Nodes_to_c.gen_bench main conf in


Printf.sprintf
"/* This code was generated by Usuba.
   See https://github.com/DadaIsCrazy/usuba.
   From the file \"%s\" (included below). */

#include <stdint.h>

/* Do NOT change the order of those define/include */
%s
#ifndef BITS_PER_REG
#define BITS_PER_REG %d
#endif
/* including the architecture specific .h */
#include \"%s\"

/* auxiliary functions */
%s

/* main function */
%s

/* Additional functions */
%s

/* **************************************************************** */
/*                            Usuba source                          */
/*                                                                  */
/*

%s

*/
 "
  filename
  (if conf.shares <> 1 then Printf.sprintf "#define MASKING_ORDER %d" conf.shares else "")
  (bits_per_reg prog conf)
  (c_header conf.archi)
  (join "\n\n" prog_c)
  entry
  bench_fun
  (Usuba_print.prog_to_str orig)
